#include <gtest/gtest.h>
#include <mockcpp/mockcpp.hpp>

#define protected public
#include "common/types.h"
#include "graph/op/array_defs.h"
#include "graph/op/math_defs.h"
#include "graph/op/const_defs.h"
#include "graph/op/nn_defs.h"
#include "graph/op/detection_defs.h"
#include "framework/graph/debug/ge_op_types.h"
#include "framework/graph/utils/tensor_utils.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/core/cgraph/graph_spec.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/cgraph/graph_finder.h"
#include "framework/graph/core/edge/edge.h"
#include "framework/graph/core/optimize/fusion/pattern_define.h"
#include "framework/graph/core/optimize/fusion/pattern_pass.h"
#include "framework/graph/core/optimize/schedule/node_optimize_scheduler.h"
#include "framework/graph/core/optimize/schedule/node_pass_list.h"
#include "infra/base/securestl.h"
#undef protected
#undef private

using namespace std;
using namespace hiai;
using namespace ge;

class UTEST_fusion_pattern_define_and_replace : public testing::Test {
protected:
    void SetUp()
    {
    }
    void TearDown()
    {
        GlobalMockObject::verify();
    }

protected:
    ge::Node* AddNode(ComputeGraphPtr graph, const string& name, const string& type, int32_t out_anchors_num = 1,
        int32_t in_anchors_num = 1)
    {
        TensorDesc tensor_desc;
        OpDescPtr opdesc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
        for (int32_t i = 0; i < in_anchors_num; i++) {
            opdesc->AddInputDesc(tensor_desc);
        }
        for (int32_t i = 0; i < out_anchors_num; i++) {
            opdesc->AddOutputDesc(tensor_desc);
        }

        ge::Node* node = graph->ROLE(GraphModifier).AddNode(opdesc);
        return node;
    }

    void SetWeightsFloat(ge::Node* node, float w)
    {
        float data[] = {w};
        TensorDesc tensor_desc(Shape(), FORMAT_NCHW, DT_FLOAT);
        TensorPtr tensor = make_shared<Tensor>(tensor_desc, (uint8_t*)data, sizeof(data));
        vector<TensorPtr> weights = {tensor};
        OpDescUtils::SetWeights(*node, weights);
    }

    Node& PatternOutNode()
    {
        return *patternOutNode_;
    }

    Node* AnyNode()
    {
        return any_;
    }

    ComputeGraphPtr CreateGraph()
    {
        ComputeGraphPtr graph = ge::ComputeGraph::Make("test_graph");

        auto any = AddNode(graph, "op_any", hiai::op::Reshape::TYPE);
        auto greater = AddNode(graph, "op_greater", hiai::op::Greater::TYPE, 1, 2);
        auto cast = AddNode(graph, "op_cast", hiai::op::CastT::TYPE);
        auto mul = AddNode(graph, "op_mul", hiai::op::Mul::TYPE, 1, 2);
        auto relu = AddNode(graph, "op_relu", RELU);

        SetWeightsFloat(greater, 0.2);

        graph->ROLE(GraphModifier).AddEdge({*any, 0}, {*greater, 0});
        graph->ROLE(GraphModifier).AddEdge({*any, 0}, {*mul, 1});
        graph->ROLE(GraphModifier).AddEdge({*greater, 0}, {*cast, 0});
        graph->ROLE(GraphModifier).AddEdge({*cast, 0}, {*mul, 0});
        graph->ROLE(GraphModifier).AddEdge({*mul, 0}, {*relu, 0});

        any_ = any;
        patternOutNode_ = mul;

        return graph;
    }

private:
    Node* patternOutNode_ = nullptr;
    Node* any_ = nullptr;
};

namespace {
// Identifiers for the nodes of the pattern below. The same ids are used to look up bound nodes
// in a match result (e.g. patternMapping->Node(NodeId::ANY) in test_pattern_match).
enum NodeId : Id { ANY, CONST0, GREATER, MUL, CAST };

// Declarative pattern describing: MUL(CAST(GREATER(ANY, CONST0)), ANY), with MUL as the output.
// - ANY declares no Type, so it presumably matches an op of any type (confirm with PatternDefine).
// - CONST0 must be a Const op; GREATER, CAST and MUL must be Greater, CastT and Mul respectively.
// - NodeInputs order matters: it lists the producers of each node's inputs in order.
// NOTE(review): PatternInputs lists GREATER and MUL — presumably the pattern nodes that consume
// input from outside the pattern (both read ANY); verify against PatternDefine's semantics.
static PatternDefine GREATER_CAST_MUL_PATTERN(
    {
        {Id {NodeId::ANY}},
        {Id {NodeId::CONST0}, Type {hiai::op::Const::TYPE}},
        {Id {NodeId::GREATER}, Type {hiai::op::Greater::TYPE}, NodeInputs {NodeId::ANY, NodeId::CONST0}},
        {Id {NodeId::CAST}, Type {hiai::op::CastT::TYPE}, NodeInputs {NodeId::GREATER}},
        {Id {NodeId::MUL}, Type {hiai::op::Mul::TYPE}, NodeInputs {NodeId::CAST, NodeId::ANY}},
    },
    PatternInputs {NodeId::GREATER, NodeId::MUL}, PatternOutput {NodeId::MUL});
} // namespace

/*
 * Node validator passed to NodePassList::Optimize in test_pattern_match.
 * It accepts every node: the argument is intentionally ignored.
 */
bool validater(ge::Node& node)
{
    static_cast<void>(node);
    const bool accepted = true;
    return accepted;
}

/*
 * Trivial INodePass used by test_pattern_match.
 * Its only attention node type is BNInference. Run() never inspects or modifies anything and
 * always reports OPTIMIZE_CHANGED.
 */
class TestNodePass : public INodePass {
public:
    std::vector<std::string> AttentionNodeTypes() override
    {
        std::vector<std::string> interestingTypes;
        interestingTypes.emplace_back(hiai::op::BNInference::TYPE);
        return interestingTypes;
    }

    OptimizeStatus Run(ge::Node& node, ge::ComputeGraph& graph) override
    {
        static_cast<void>(node);
        static_cast<void>(graph);
        return OptimizeStatus::OPTIMIZE_CHANGED;
    }
};

/*
 * Matches GREATER_CAST_MUL_PATTERN against the fixture graph, with op_mul as the candidate output.
 * It then checks that the ANY slot binds to op_any. Finally it exercises NodePassList and
 * NodeOptimizeScheduler with a trivial pass.
 * ASSERT_* (fatal) is used wherever the checked pointer is dereferenced afterwards, so a failure
 * aborts the test cleanly instead of crashing on a null dereference.
 */
TEST_F(UTEST_fusion_pattern_define_and_replace, test_pattern_match)
{
    ComputeGraphPtr graph = CreateGraph();
    ASSERT_NE(graph, nullptr);

    auto patternMapping = GREATER_CAST_MUL_PATTERN.Match(PatternOutNode(), *graph);
    GREATER_CAST_MUL_PATTERN.AttentionTypes();
    ASSERT_TRUE(patternMapping != nullptr);
    EXPECT_TRUE(patternMapping->Node(NodeId::ANY) == AnyNode());

    auto passList = hiai::NodePassList::Make();
    ASSERT_NE(passList, nullptr);
    std::unique_ptr<INodePass> testpass = hiai::make_unique_nothrow<TestNodePass>();
    // Fatal: a null pass must not be handed to the pass list.
    ASSERT_NE(testpass, nullptr);
    passList->Add(std::move(testpass));

    Node* node = graph->ROLE(GraphFinder).FindNode("op_any");
    ASSERT_NE(node, nullptr);
    passList->Optimize(*node, *graph, validater);
    const char* type = "Const";
    passList->Attention(type);

    auto nodePassScheduler = hiai::NodeOptimizeScheduler::Make(*passList, 1);
    ASSERT_NE(nodePassScheduler, nullptr);
    Status ret = nodePassScheduler->Schedule(*graph);
    EXPECT_EQ(ret, hiai::SUCCESS);
}

/*
 * INodePass whose attention node type is Permute.
 * Each Run() adds a temporary Permute node ("permute2", input {1,320,160,3} ->
 * output {1,3,320,160}, order {0,3,1,2}) to the graph and immediately removes it again.
 * It returns OPTIMIZE_FAILED if adding or removing the node fails; otherwise it returns
 * OPTIMIZE_CHANGED.
 */
class TestPermuteNodePass : public INodePass {
public:
    OptimizeStatus Run(ge::Node& node, ge::ComputeGraph& graph) override
    {
        static_cast<void>(node);

        // first create Permute Op
        ge::OpDescPtr p2Op = std::make_shared<ge::OpDesc>("permute2", std::string {hiai::op::Permute::TYPE});
        ge::TensorDesc p2In(ge::Shape({1, 320, 160, 3}));
        ge::TensorDesc p2Out(ge::Shape({1, 3, 320, 160}));
        p2Op->AddInputDesc(p2In);
        p2Op->AddOutputDesc(p2Out);
        ge::AttrUtils::SetListInt(p2Op, hiai::op::Permute::order, std::vector<int64_t>({0, 3, 1, 2}));
        ge::Node* op = graph.ROLE(GraphModifier).AddNode(p2Op);
        // AddNode can fail; never dereference its result unchecked.
        if (op == nullptr) {
            return OptimizeStatus::OPTIMIZE_FAILED;
        }

        // then delete Permute Op
        if (graph.ROLE(GraphModifier).RemoveNode(*op) != hiai::SUCCESS) {
            return OptimizeStatus::OPTIMIZE_FAILED;
        }

        return OptimizeStatus::OPTIMIZE_CHANGED;
    }

    std::vector<std::string> AttentionNodeTypes() override
    {
        return {hiai::op::Permute::TYPE};
    }
};

/*
 * Builds data -> permute1 -> netout, then schedules TestPermuteNodePass, which creates and
 * deletes a temporary Permute node.
 * Expects scheduling to succeed and the graph to end with its original 3 nodes.
 * Fatal ASSERT_* guards protect every pointer that is dereferenced later in the test.
 */
TEST_F(UTEST_fusion_pattern_define_and_replace, test_pass_create_op_then_delete_op)
{
    ge::ComputeGraphPtr graph = ge::ComputeGraph::Make("test");
    ASSERT_NE(graph, nullptr);

    ge::OpDescPtr dataOp = std::make_shared<ge::OpDesc>("data", ge::DATA);
    ge::TensorDesc dataOut(ge::Shape({1, 3, 320, 160}));
    dataOp->AddOutputDesc(dataOut);
    ge::NodePtr data = graph->AddNode(dataOp);
    ASSERT_NE(data, nullptr);

    ge::OpDescPtr p1Op = std::make_shared<ge::OpDesc>("permute1", std::string {hiai::op::Permute::TYPE});
    ge::TensorDesc p1In(ge::Shape({1, 3, 320, 160}));
    ge::TensorDesc p1Out(ge::Shape({1, 320, 160, 3}));
    p1Op->AddInputDesc(p1In);
    p1Op->AddOutputDesc(p1Out);
    ge::AttrUtils::SetListInt(p1Op, hiai::op::Permute::order, std::vector<int64_t>({0, 2, 3, 1}));
    ge::NodePtr p1 = graph->AddNode(p1Op);
    ASSERT_NE(p1, nullptr);

    ge::OpDescPtr netoutOpDesc = std::make_shared<ge::OpDesc>("netout", ge::NETOUTPUT);
    ge::TensorDesc netoutTensor(ge::Shape({1, 320, 160, 3}));
    netoutOpDesc->AddInputDesc(netoutTensor);
    netoutOpDesc->AddOutputDesc(netoutTensor);
    ge::NodePtr netout = graph->AddNode(netoutOpDesc);
    ASSERT_NE(netout, nullptr);

    graph->ROLE(GraphModifier).AddEdge({*data, 0}, {*p1, 0});
    graph->ROLE(GraphModifier).AddEdge({*p1, 0}, {*netout, 0});

    auto passList = hiai::NodePassList::Make();
    ASSERT_NE(passList, nullptr);
    // Build the pass into a named owner (no redundant std::move on a temporary) and refuse to
    // hand a null pass to the list.
    std::unique_ptr<INodePass> permutePass = hiai::make_unique_nothrow<TestPermuteNodePass>();
    ASSERT_NE(permutePass, nullptr);
    passList->Add(std::move(permutePass));

    auto nodePassScheduler = hiai::NodeOptimizeScheduler::Make(*passList, 1);
    ASSERT_NE(nodePassScheduler, nullptr);
    Status status = nodePassScheduler->Schedule(*graph);
    EXPECT_EQ(status, hiai::SUCCESS);
    // The temporary "permute2" node must not survive: data, permute1 and netout only.
    EXPECT_EQ(graph->ROLE(GraphSpec).NodesNum(), 3);
}